import torch
from torch import nn
import numpy as np
import Setting
import os


def Net(board_size=6):
    """Build the board-evaluation CNN: one input channel in, one scalar out.

    Architecture:
      * two 3x3 convs (1->3->3 channels, padding 1 keeps the spatial size),
      * a 2x2 max-pool (halves the spatial size, flooring odd sizes),
      * two more 3x3 convs (3->8->8 channels),
      * an MLP head 8*k*k -> 8 -> 1, where k = board_size // 2.

    Args:
        board_size: side length of the square input, which is fed as
            (batch, 1, board_size, board_size). The default of 6 gives k = 3,
            i.e. exactly the original ``Linear(8 * 3 * 3, 8)``; a 7x7 board also
            pools to 3x3 and therefore works with the default as well.

    Returns:
        An ``nn.Sequential``. The layer order (and thus state_dict keys such as
        ``"10.weight"``) is identical to the original, so saved weights still load.
    """
    pooled = board_size // 2  # spatial side after MaxPool2d(2, 2)
    return nn.Sequential(
        nn.Conv2d(1, 3, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(3, 3, 3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(3, 8, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 8, 3, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * pooled * pooled, 8),
        nn.ReLU(),
        nn.Linear(8, 1),
    )

# Module-level singletons shared by the load/save/train/predict helpers.
loss = nn.MSELoss()  # mean-squared error between prediction and target
net = Net()
optimize = torch.optim.Adam(params=net.parameters(), lr=1e-2)

def Load():
    """Restore ``net``'s weights from ``Setting.weight_path`` if the file exists.

    Returns:
        True when a checkpoint was found and loaded, False when there was none
        (the network then keeps its current weights).
    """
    if not os.path.exists(Setting.weight_path):
        return False
    # map_location="cpu" lets a checkpoint written on a GPU machine load on a
    # CPU-only host; nothing in this module moves ``net`` off the CPU.
    state = torch.load(Setting.weight_path, map_location="cpu")
    net.load_state_dict(state)
    return True

def Save():
    """Persist the current ``net`` parameters (its state_dict) to ``Setting.weight_path``."""
    weights = net.state_dict()
    torch.save(weights, Setting.weight_path)

def Train(X, y, steps=100):
    """Fit the shared ``net`` to (X, y) with ``steps`` full-batch Adam updates.

    Args:
        X: array-like batch of boards, converted to a float32 tensor (float64
            arrays and plain lists are accepted). Fed straight to ``net``, so it
            should be shaped (batch, 1, board_size, board_size).
        y: array-like targets, one per sample. Reshaped to (batch, 1) to match
            the network's single-output head, so a flat (batch,) array no longer
            silently broadcasts to a (batch, batch) loss inside ``MSELoss``.
        steps: number of optimizer steps over the whole batch.

    Returns:
        The loss computed on the final step (before that step's update) as a
        float, or None when ``steps`` <= 0.
    """
    X = torch.as_tensor(X, dtype=torch.float32)
    y = torch.as_tensor(y, dtype=torch.float32).reshape(-1, 1)
    # Diagnostic print of predictions vs. targets before training; no graph needed.
    with torch.no_grad():
        print(net(X), y)
    step_loss = None
    for _ in range(steps):
        optimize.zero_grad()
        step_loss = loss(net(X), y)
        step_loss.backward()
        optimize.step()
    return None if step_loss is None else step_loss.item()

def Pre(X):
    """Return ``net``'s scalar prediction for a single board.

    Args:
        X: one 2-D board (height, width), array-like; converted to float32 so
            float64 arrays and lists are accepted. Batch and channel dimensions
            are added here, giving (1, 1, height, width).

    Returns:
        The network output as a Python float.
    """
    X = torch.as_tensor(X, dtype=torch.float32)
    X = X.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        return float(net(X)[0, 0])

def test():
    """Smoke test: load any saved weights, train on one all-ones board toward a
    target of 1 (default step count), then print the net's output for the
    all-ones board and for its negation. Weights are not saved afterwards."""
    side = Setting.board_size
    ones = np.ones((1, 1, side, side), np.float32)
    target = np.ones((1, 1), np.float32)
    Load()
    Train(ones, target)
    for board in (ones, -ones):
        print(net(torch.from_numpy(board)))

# Run the smoke test only when this file is executed directly, not on import.
if __name__ == '__main__':
    test()



